archived/Text_Classification_BERT/scripts/train_bert.py

import argparse
import logging
import os
import time
from typing import Any

import pandas as pd
import torch
import numpy as np
import pytorch_lightning as pl
from sklearn.model_selection import train_test_split
from torch.optim import AdamW
from torch.utils.data import DataLoader, TensorDataset
from transformers import BertForSequenceClassification, BertTokenizer, get_linear_schedule_with_warmup

### implementation - optional
from smart_sifting.data_model.data_model_interface import SiftingBatch, SiftingBatchTransform
from smart_sifting.data_model.list_batch import ListBatch

#### interface
from smart_sifting.dataloader.sift_dataloader import SiftingDataloader
from smart_sifting.loss.abstract_sift_loss_module import Loss
from smart_sifting.metrics.lightning import TrainingMetricsRecorder
from smart_sifting.sift_config.sift_configs import (
    RelativeProbabilisticSiftConfig,
    LossConfig,
    SiftingBaseConfig,
)

"""
This is the SageMaker entry point script that trains a BERT model on the public CoLA dataset
using PyTorch Lightning modules.

Note that this script should only be used as a SageMaker entry point. It assumes the data is
saved in an S3 location that is specified by the "SM_CHANNEL_DATA" environment variable set by
SageMaker. It also expects the following parameters to be passed in as arguments:

    num_nodes   - the number of nodes, or instances, that the training script should run on
    epochs      - the number of training epochs
    log_level   - the logging level
    log_batches - whether or not to log batches; useful for debugging runs

See run_bert_ptl.py for the SageMaker PyTorch Estimator that sets these parameters. An example
invocation is sketched in a comment at the bottom of this file.

The data processing and training code is adapted from
https://www.kaggle.com/code/hassanamin/bert-pytorch-cola-classification/notebook
"""

RANDOM_SEED = 7

# Setting up logger for this module
logger = logging.getLogger(__name__)


class BertLoss(Loss):
    """
    This is an implementation of the Loss interface for the BERT model required for Smart Sifting.

    Uses cross-entropy loss with 2 classes.
    """

    def __init__(self):
        self.celoss = torch.nn.CrossEntropyLoss(reduction='none')

    def loss(
        self,
        model: torch.nn.Module,
        transformed_batch: SiftingBatch,
        original_batch: Any = None,
    ) -> torch.Tensor:
        # Get the original batch onto the model's device. Note that we are assuming the model is
        # already on the right device here; PyTorch Lightning takes care of this under the hood
        # with the model that's passed in.
        # TODO: ensure batch and model are on the same device in SiftDataloader so that the
        # customer doesn't have to implement this
        device = next(model.parameters()).device
        batch = [t.to(device) for t in original_batch]

        # compute loss
        outputs = model(batch)
        return self.celoss(outputs.logits, batch[2])


class BertListBatchTransform(SiftingBatchTransform):
    """
    This is an implementation of the data transforms for the BERT model required for Smart Sifting.
    Transform to and from ListBatch.
    """

    def transform(self, batch: Any):
        inputs = []
        for i in range(len(batch[0])):
            inputs.append((batch[0][i], batch[1][i]))

        labels = batch[2].tolist()  # assume the last element is the list of labels
        return ListBatch(inputs, labels)

    def reverse_transform(self, list_batch: ListBatch):
        inputs = list_batch.inputs
        input_ids = [iid for (iid, _) in inputs]
        masks = [mask for (_, mask) in inputs]
        a_batch = [
            torch.stack(input_ids),
            torch.stack(masks),
            torch.tensor(list_batch.labels, dtype=torch.int64),
        ]
        return a_batch


class ColaDataModule(pl.LightningDataModule):
    def __init__(self, batch_size: int, model: torch.nn.Module, log_batches: bool):
        super().__init__()
        self.batch_size = batch_size
        self.model = model
        self.log_batches = log_batches

    def setup(self, stage: str) -> None:
        """
        Loads the data from S3 and splits it into train, validation, and test datasets.
        This logic is dataset specific.
        """
        logger.info("Preprocessing CoLA dataset")

        # The environment variable for the path to the dataset is SM_CHANNEL_{channel_name}, which
        # should match the channel name specified in the estimator in `run_bert_ptl.py`
        data_path = os.environ["SM_CHANNEL_DATA"]
        dataframe = pd.read_csv(f"{data_path}/train.tsv", sep="\t")

        # Split dataframes. (Note: we use scikit-learn here because PyTorch's random_split doesn't
        # work as intended - there's a bug when we pass proportions to random_split
        # (https://stackoverflow.com/questions/74327447/how-to-use-random-split-with-percentage-split-sum-of-input-lengths-does-not-equ)
        # and we get a KeyError when iterating through the resulting split datasets.)
        logger.info("Splitting dataframes into train, val, and test")
        train_df, test_df = train_test_split(dataframe, train_size=0.9, random_state=RANDOM_SEED)
        train_df, val_df = train_test_split(train_df, train_size=0.9, random_state=RANDOM_SEED)

        # Finally, transform the dataframes into PyTorch datasets
        logger.info("Transforming dataframes into datasets")
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
        max_sentence_length = 128
        self.train = self._transform_to_dataset(train_df, tokenizer, max_sentence_length)
        self.val = self._transform_to_dataset(val_df, tokenizer, max_sentence_length)
        self.test = self._transform_to_dataset(test_df, tokenizer, max_sentence_length)

        logger.info("Done preprocessing CoLA dataset")

    def train_dataloader(self):
        sift_config = RelativeProbabilisticSiftConfig(
            beta_value=3,
            loss_history_length=500,
            loss_based_sift_config=LossConfig(
                sift_config=SiftingBaseConfig(sift_delay=10)
            )
        )
        return SiftingDataloader(
            sift_config=sift_config,
            orig_dataloader=DataLoader(self.train, self.batch_size, shuffle=True),
            loss_impl=BertLoss(),
            model=self.model,
            batch_transforms=BertListBatchTransform(),
        )

    def val_dataloader(self):
        return DataLoader(self.val, self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test, self.batch_size)

    def predict_dataloader(self):
        return DataLoader(self.test, self.batch_size)

    def _transform_to_dataset(self, dataframe: pd.DataFrame, tokenizer, max_sentence_length):
        sentences = dataframe.sentence.values
        labels = dataframe.label.values

        input_ids = []
        for sent in sentences:
            encoded_sent = tokenizer.encode(sent, add_special_tokens=True)
            input_ids.append(encoded_sent)

        # pad shorter sentences
        input_ids_padded = []
        for i in input_ids:
            while len(i) < max_sentence_length:
                i.append(0)
            input_ids_padded.append(i)
        input_ids = input_ids_padded

        # attention masks: 1 for real tokens, 0 for padding (added) tokens
        attention_masks = []
        # For each sentence...
        for sent in input_ids:
            att_mask = [int(token_id > 0) for token_id in sent]
            attention_masks.append(att_mask)

        # convert to PyTorch data types
        inputs = torch.tensor(input_ids)
        labels = torch.tensor(labels)
        masks = torch.tensor(attention_masks)

        return TensorDataset(inputs, masks, labels)


class BertLitModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.save_hyperparameters()
        self.model = self._create_model()
        self.celoss = torch.nn.CrossEntropyLoss()

    def forward(self, batch):
        return self.model(batch[0], token_type_ids=None, attention_mask=batch[1])

    def training_step(self, batch, batch_idx):
        # Forward pass
        outputs = self(batch)
        loss = self.celoss(outputs.logits, batch[2])
        self.log("train_loss", loss)
        return loss

    def evaluate(self, batch, stage=None):
        outputs = self(batch)

        # Move tensors to CPU
        logits = outputs.logits.detach().cpu().numpy()
        label_ids = batch[2].to('cpu').numpy()

        # compute accuracy
        acc = self._flat_accuracy(logits, label_ids)
        if stage:
            self.log(f"{stage}_acc", acc, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        self.evaluate(batch, "val")

    def test_step(self, batch, batch_idx):
        self.evaluate(batch, "test")

    def configure_optimizers(self):
        logger.info("Initializing AdamW optimizer")
        optimizer = AdamW(
            self.model.parameters(),
            lr=2e-5,   # args.learning_rate - default is 5e-5, this script uses 2e-5
            eps=1e-8,  # args.adam_epsilon - default is 1e-8
        )

        # Create the learning rate scheduler.
        logger.info("Initializing learning rate scheduler")
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=0,  # Default value in run_glue.py
            num_training_steps=self.trainer.estimated_stepping_batches,
        )
        return [optimizer], [scheduler]

    def _create_model(self):
        logger.info("Creating BertForSequenceClassification")
        model = BertForSequenceClassification.from_pretrained(
            "bert-base-uncased",         # Use the 12-layer BERT model with an uncased vocab.
            num_labels=2,                # The number of output labels - 2 for binary classification.
            output_attentions=False,     # Whether the model returns attention weights.
            output_hidden_states=False,  # Whether the model returns all hidden states.
        )
        return model

    # Function to calculate the accuracy of our predictions vs labels
    def _flat_accuracy(self, preds, labels):
        pred_flat = np.argmax(preds, axis=1).flatten()
        labels_flat = labels.flatten()
        return np.sum(pred_flat == labels_flat) / len(labels_flat)


def main(args):
    # Setting up logger config
    logging.basicConfig(
        format='%(asctime)s.%(msecs)03d PID:%(process)d %(levelname)s %(module)s - %(funcName)s: %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        level=args.log_level,
        force=True,
    )

    pl.seed_everything(RANDOM_SEED)

    model = BertLitModule()
    # For fine-tuning BERT on a specific task, the authors recommend a batch size of 16 or 32.
    data = ColaDataModule(batch_size=32, model=model, log_batches=args.log_batches)

    trainer = pl.Trainer(
        # The BERT authors recommend 2 - 4 epochs
        max_epochs=args.epochs,
        accelerator="auto",
        strategy="ddp",
        num_nodes=args.num_nodes,
        # Clip the norm of the gradients to 1.0. This helps prevent the "exploding gradients" problem.
        gradient_clip_val=1,
        callbacks=[TrainingMetricsRecorder()],
    )
    trainer.fit(model, data)


# Converts string argument to boolean
def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


if __name__ == "__main__":
    start_time = time.perf_counter()

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num_nodes", type=int, default=1, metavar="N",
        help="number of training nodes (default: 1)"
    )
    parser.add_argument(
        "--epochs", type=int, default=2, metavar="N",
        help="number of training epochs (default: 2)"
    )
    parser.add_argument(
        "--log_level", type=int, default=1, metavar="N",
        help="log level (default: 1)"
    )
    parser.add_argument(
        "--log_batches", type=str2bool, nargs="?", const=True, default=False, metavar="N",
        help="whether or not to log batches (default: False)"
    )
    args = parser.parse_args()

    main(args)

    logger.info(f"Total time : {time.perf_counter() - start_time}")
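
# Example invocation (hypothetical, for local debugging outside SageMaker). The script reads the
# CoLA train.tsv from the directory named by SM_CHANNEL_DATA and accepts the arguments defined
# above, so a sketch like the following could exercise it, assuming /tmp/cola/train.tsv exists
# locally and the smart_sifting package is installed:
#
#   SM_CHANNEL_DATA=/tmp/cola python train_bert.py --num_nodes 1 --epochs 2 \
#       --log_level 20 --log_batches false
#
# When launched through SageMaker, run_bert_ptl.py sets these hyperparameters on the PyTorch
# Estimator and SageMaker exports SM_CHANNEL_DATA for the "data" channel automatically.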